import torch
import triton
import triton.language as tl

# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
#                   for BLOCK_N in [2048, 4096, 8192]
#                   for ns in [1, 2, 4]
#                   for nw in [1, 2, 4, 8, 16]],
#                   key=['N'])
@triton.jit
def _selective_log_softmax_kernel(LOGITS,
                                  INPUT_IDS,
                                  LOG_P,
                                  MASK,
                                  TEMPERATURE,
                                  stride_input_ids_b,
                                  L: tl.constexpr,
                                  N: tl.constexpr,
                                  BLOCK_N:tl.constexpr=4096):
    """Write log_softmax(LOGITS[b, l, :] / TEMPERATURE)[INPUT_IDS[b, l]] to LOG_P[b, l].

    One program per (b, l) token position; program ids 0 and 1 select b and l.
    LOGITS is addressed as a dense row-major (B, L+1, N) buffer; row l = L is
    never read. LOG_P is addressed as a dense (B, L) buffer.
    INPUT_IDS row b starts at b * stride_input_ids_b and its columns are
    assumed adjacent (unit column stride). MASK, when not None, is addressed
    with the SAME row stride as INPUT_IDS; where it is 0 the program returns
    without writing LOG_P.
    """
    # 64-bit program ids so the pointer offsets below (up to B*(L+1)*N) are
    # computed in 64-bit arithmetic.
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    LOGITS += off_b * (L+1) * N + off_l * N
    INPUT_IDS += off_b * stride_input_ids_b + off_l
    LOG_P += off_b * L + off_l

    # MASK is a constexpr-specialized optional pointer; the check compiles away
    # when no mask is passed.
    if MASK is not None:
        MASK += off_b * stride_input_ids_b + off_l
        not_skip = tl.load(MASK)
        if not_skip == 0:
            return

    # Streaming log-sum-exp over the vocabulary in BLOCK_N-wide chunks:
    # m_i is the running max and l_i the running sum of exp(x - m_i). Each
    # time the max grows, the old sum is rescaled by alpha. Out-of-range lanes
    # load -inf, so they contribute exp(-inf) = 0.
    m_i = float('-inf')
    l_i = 0. 
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=float('-inf')).to(tl.float32) / TEMPERATURE
        new_m_i = tl.maximum(m_i, tl.max(logits))
        alpha = tl.exp(m_i - new_m_i)
        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
        m_i = new_m_i
    lse = m_i + tl.log(l_i)

    # Gather the selected token's scaled logit and normalize it.
    ids = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
    logp = x - lse
    tl.store(LOG_P, logp)
    

# Used to compute old_logp and ref_logp. No gradient is needed there, so only
# a forward pass is implemented.
def fused_selective_log_softmax(logits:torch.Tensor, input_ids:torch.Tensor, temperature:float=0.9, mask=None):
    """Return temperature-scaled log-probabilities of the selected tokens.

    Args:
        logits: contiguous tensor of shape (B, L+1, N). Position ``l`` along
            the second axis is paired with column ``l`` of the last L columns
            of ``input_ids``. The final position (index L) is not read.
        input_ids: 2-D tensor with at least L columns; only the last L
            columns are used.
        temperature: the logits are divided by this before the log-softmax.
        mask: optional 2-D tensor with at least L columns; only the last L
            columns are used. Positions where it equals 0 are skipped and
            keep the value 0 in the output.

    Returns:
        float32 tensor of shape (B, L).
    """
    assert logits.is_contiguous()
    B, L_ADD_1, N = logits.shape
    L = L_ADD_1 - 1
    assert input_ids.shape[-1] >= L, "input_ids has fewer than L columns"
    if mask is not None:
        assert mask.shape[-1] >= L, "mask has fewer than L columns"
    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
    if L == 0:
        # Nothing to compute, and `x[:, -0:]` would select every column.
        return log_p
    # The kernel addresses MASK with input_ids' row stride and assumes unit
    # column stride for both. Materializing the (B, L) slices contiguously
    # gives both the same (L, 1) strides, whatever their original widths.
    input_ids = input_ids[:, -L:].contiguous()
    if mask is not None:
        mask = mask[:, -L:].contiguous()
    kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
    _selective_log_softmax_kernel[(B, L)](logits,
                                          input_ids,
                                          log_p,
                                          mask,
                                          temperature,
                                          input_ids.stride(0),
                                          L,
                                          N,
                                          **kwargs
                                          )
    return log_p





# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
#                   for BLOCK_N in [2048, 4096, 8192]
#                   for ns in [1, 2, 4]
#                   for nw in [1, 2, 4, 8, 16]],
#                   key=['N'])
@triton.jit
def _grpo_loss_fwd_kernel(LOGITS,
                         OLD_LOGP,
                         REF_LOGP,
                        INPUT_IDS,
                        COMPLETION_MASK,
                        ADVANTAGES,
                        LOSS,
                        LSE,
                        KL,
                        IS_CLIPPED,
                        TEMPERATURE,
                        BETA:tl.constexpr,
                        EPS_LOW,
                        EPS_HIGH,
                        L: tl.constexpr,
                        N: tl.constexpr,
                        BLOCK_N:tl.constexpr=4096):
    """Per-token GRPO forward: clipped policy loss plus optional KL penalty.

    One program per (b, l) token position. LOGITS is addressed as a dense
    (B, L+1, N) buffer. INPUT_IDS, COMPLETION_MASK, OLD_LOGP, REF_LOGP, LOSS,
    LSE, KL and IS_CLIPPED are addressed as dense (B, L) buffers.
    ADVANTAGES holds one value per b.

    With logp = log_softmax(LOGITS[b, l] / TEMPERATURE)[INPUT_IDS[b, l]] and
    r = exp(logp - old_logp):
        LOSS = -min(r * A, clamp(r, 1 - EPS_LOW, 1 + EPS_HIGH) * A)
               + BETA * kl                       (KL term only if BETA != 0)
        kl   = exp(ref_logp - logp) - (ref_logp - logp) - 1
    LSE stores the log-sum-exp so the backward kernel can recompute logp.
    Positions where COMPLETION_MASK == 0 return without writing any output.
    """
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    if COMPLETION_MASK is not None:
        COMPLETION_MASK += off_b * L + off_l
        not_skip = tl.load(COMPLETION_MASK)
        if not_skip == 0:
            return
        
    LOGITS += off_b * (L+1) * N + off_l * N
    INPUT_IDS += off_b * L + off_l
    ADVANTAGES += off_b
    LOSS += off_b * L + off_l
    LSE += off_b * L + off_l
    IS_CLIPPED += off_b * L + off_l
        
    # Streaming log-sum-exp over the vocabulary (running max m_i, running
    # rescaled sum l_i); padded lanes load -inf and contribute 0.
    m_i = float('-inf')
    l_i = 0. 
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=float('-inf')).to(tl.float32) / TEMPERATURE
        new_m_i = tl.maximum(m_i, tl.max(logits))
        alpha = tl.exp(m_i - new_m_i)
        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
        m_i = new_m_i
    lse = m_i + tl.log(l_i)

    idx = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
    logp = x - lse
    # Without OLD_LOGP the importance ratio is exactly 1.
    if OLD_LOGP is None:
        old_logp = logp
    else:
        OLD_LOGP += off_b * L + off_l
        old_logp = tl.load(OLD_LOGP).to(tl.float32)
    coef_1 = tl.exp(logp - old_logp)
    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
    advantage = tl.load(ADVANTAGES).to(tl.float32)
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
    # NOTE(review): this is True where min() picks the UNclipped term
    # (loss1 < loss2). Confirm that this is the intended "clipped" metric.
    is_clipped = per_token_loss1 < per_token_loss2

    # BETA is constexpr, so the KL branch (and REF_LOGP/KL accesses) compiles
    # away entirely when BETA == 0.
    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        KL += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
        per_token_loss += BETA * kl
        tl.store(KL, kl)
        
    tl.store(LOSS, per_token_loss)
    tl.store(LSE, lse)
    tl.store(IS_CLIPPED, is_clipped)
    

# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
#                   for BLOCK_N in [2048, 4096, 8192]
#                   for ns in [1, 2, 4]
#                   for nw in [1, 2, 4, 8, 16]],
#                   key=['N'])
@triton.jit
def _grpo_loss_bwd_kernel(DLOSS,
                        DLOGITS,
                        LOGITS,
                         OLD_LOGP,
                         REF_LOGP,
                        INPUT_IDS,
                        ADVANTAGES,
                        COMPLETION_MASK,
                        LSE,
                        TEMPERATURE,
                        BETA:tl.constexpr,
                        EPS_LOW,
                        EPS_HIGH,
                        loss_stride0,
                        loss_stride1,
                        L: tl.constexpr,
                        N: tl.constexpr,
                        BLOCK_N:tl.constexpr=4096):
    """Per-token GRPO backward: write d(loss)/d(logits) for one (b, l) row.

    DLOSS is the upstream gradient of LOSS, addressed with explicit strides
    (loss_stride0, loss_stride1), so broadcast/expanded gradients work.
    DLOGITS and LOGITS are addressed as dense (B, L+1, N) buffers and may be
    the same tensor (GrpoLoss.backward passes logits itself when inplace=True).
    This kernel never writes row l = L.

    Rows with COMPLETION_MASK == 0 are zero-filled. Otherwise logp is
    recomputed from LOGITS and the saved LSE, and
        dlogp = -r * A * [clamp(r)*A >= r*A] + BETA * (1 - exp(ref_logp - logp))
        DLOGITS[j] = dlogp * dloss / TEMPERATURE * (1[j == idx] - softmax_j)
    """

    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    
    DLOGITS += off_b * (L+1) * N + off_l * N
    if COMPLETION_MASK is not None:
        COMPLETION_MASK += off_b * L + off_l
        not_skip = tl.load(COMPLETION_MASK)
        if not_skip == 0:
            # Masked token: its logits row gets an all-zero gradient.
            for start in range(0, N, BLOCK_N):
                cols = tl.arange(0, BLOCK_N) + start
                tl.store(DLOGITS+cols, 0., mask=cols<N)
            return
    
    LOGITS += off_b * (L+1) * N + off_l * N
    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
    INPUT_IDS += off_b * L + off_l
    ADVANTAGES += off_b
    LSE += off_b * L + off_l

    dloss = tl.load(DLOSS).to(tl.float32)
    lse = tl.load(LSE).to(tl.float32)

    idx = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
    logp = x - lse
    if OLD_LOGP is None:
        old_logp = logp
    else:
        OLD_LOGP += off_b * L + off_l
        old_logp = tl.load(OLD_LOGP).to(tl.float32)
    coef_1 = tl.exp(logp - old_logp)
    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
    advantage = tl.load(ADVANTAGES).to(tl.float32)
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    # Gradient flows only through the unclipped term, i.e. where min() selects
    # per_token_loss1. When the clipped term is strictly smaller, the clamp is
    # active and its derivative is 0.
    mask = per_token_loss2 >= per_token_loss1

    # d(-r*A)/dlogp = -r*A because r = exp(logp - old_logp).
    dlogp = -per_token_loss1 * mask
    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
        # d/dlogp [exp(ref - logp) - (ref - logp) - 1] = 1 - exp(ref - logp)
        dlogp += BETA * (1 - tl.exp(ref_logp - logp))
    
    # Chain rule through logp = LOGITS[idx] / T - lse.
    dlogp = dlogp * dloss / TEMPERATURE
    # The original author credited a colleague ("Youshen") for this line.
    # NOTE(review): when DLOGITS aliases LOGITS (inplace mode), the loop below
    # overwrites LOGITS. This barrier presumably ensures every thread has
    # finished the LOGITS[idx] read above before any store can clobber it.
    tl.debug_barrier()
    for start_n in tl.range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS+cols, mask=cols < N, other=0.).to(tl.float32) / TEMPERATURE
        probs = tl.exp(logits - lse)
        dlogits = tl.where(cols==idx, 1-probs, -probs) * dlogp
        tl.store(DLOGITS+cols, dlogits, mask=cols < N)
        

class GrpoLoss(torch.autograd.Function):
    """Fused GRPO loss: per-token clipped policy loss plus an optional KL penalty.

    ``logits`` is a contiguous (B, L+1, N) tensor. Its positions ``0..L-1``
    score the L completion tokens, and position ``L`` receives a zero
    gradient. Only ``logits`` receives a gradient; every other input is
    treated as a constant.
    """

    @staticmethod
    def forward(ctx, logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_low, eps_high, inplace):
        """Compute the per-token loss, KL and clip indicator.

        Args:
            logits: contiguous (B, L+1, N) logits.
            old_logp: optional contiguous log-probs holding B*L values. When
                None, the importance ratio is taken as 1.
            ref_logp: contiguous reference log-probs holding B*L values.
                Required when ``beta != 0``.
            completion_ids: contiguous token ids holding B*L values.
            advantages: one value per sequence (B values).
            completion_mask: optional, contiguous, B*L values. Positions equal
                to 0 are skipped (zero loss and zero gradient).
            temperature: divisor applied to the logits.
            beta: KL coefficient; 0 disables the KL term.
            eps_low, eps_high: lower / upper clip margins for the ratio.
            inplace: if True, backward writes the gradient into the storage of
                ``logits``, destroying the logits.

        Returns:
            ``(loss, kl, is_clipped)``: float32 (B, L) tensors. ``kl`` is None
            when ``beta == 0``. ``kl`` and ``is_clipped`` are diagnostics and
            are marked non-differentiable. The KL gradient is already carried
            by ``loss``.
        """
        assert logits.is_contiguous() and completion_ids.is_contiguous()
        assert old_logp is None or old_logp.is_contiguous()
        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True

        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1

        if completion_mask is not None:
            assert completion_mask.is_contiguous()

        # The kernels read advantages[b] with unit stride.
        advantages = advantages.contiguous()
        # The kernels address each per-token tensor as a flat row-major (B, L)
        # buffer and advantages as (B,). A size mismatch would otherwise read
        # the wrong elements, or past the end of the buffer, without an error.
        per_token = {"completion_ids": completion_ids,
                     "old_logp": old_logp,
                     "completion_mask": completion_mask}
        if beta != 0.0:
            per_token["ref_logp"] = ref_logp
        for name, tensor in per_token.items():
            assert tensor is None or tensor.numel() == B * L, \
                f"{name} must hold B*L={B * L} elements, got {tensor.numel()}"
        assert advantages.numel() == B, \
            f"advantages must hold B={B} elements, got {advantages.numel()}"

        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
        lse = torch.zeros_like(loss)
        is_clipped = torch.zeros_like(loss)
        kl = torch.zeros_like(loss) if beta != 0.0 else None
        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
        _grpo_loss_fwd_kernel[(B, L)](logits,
                                     old_logp,
                                     ref_logp,
                                     completion_ids,
                                     completion_mask,
                                     advantages,
                                     loss,
                                     lse,
                                     kl,
                                     is_clipped,
                                     temperature,
                                     beta,
                                     eps_low,
                                     eps_high,
                                     L,
                                     N,
                                     **kwargs
                                     )
        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
        # backward ignores gradients w.r.t. kl / is_clipped, so expose them as
        # constants. mark_non_differentiable replaces its previous argument
        # list on every call, so pass both outputs in one call.
        ctx.mark_non_differentiable(*(t for t in (kl, is_clipped) if t is not None))
        return loss, kl, is_clipped

    @staticmethod
    def backward(ctx, *grads):
        """Return the gradient w.r.t. ``logits``; every other input gets None.

        Only ``grads[0]`` (the gradient of ``loss``) is used. ``kl`` and
        ``is_clipped`` are non-differentiable.
        """
        dloss = grads[0]
        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
        temperature, beta, eps_low, eps_high, inplace = ctx.infos
        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1
        # In-place mode reuses the logits storage for the gradient, avoiding a
        # second (B, L+1, N) allocation. The kernel loads each chunk of a row
        # before storing to it; afterwards the logits are gone.
        dlogits = logits if inplace else torch.empty_like(logits)
        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
        _grpo_loss_bwd_kernel[(B, L)](dloss,
                                      dlogits,
                                      logits,
                                      old_logp,
                                      ref_logp,
                                      completion_ids,
                                      advantages,
                                      completion_mask,
                                      lse,
                                      temperature,
                                      beta,
                                      eps_low,
                                      eps_high,
                                      *dloss.stride(),
                                      L,
                                      N,
                                      **kwargs
                                      )
        # The kernel grid only covers positions 0..L-1; the last position has
        # no target token, so its gradient is zero.
        dlogits[:, -1, :] = 0
        return dlogits, None, None, None, None, None, None, None, None, None, None

def triton_grpo_loss(logits,
                     old_logp,
                     ref_logp,
                     completion_ids,
                     advantages,
                     completion_mask=None,
                     temperature=0.9,
                     beta=0.04,
                     eps_low=0.2,
                     eps_high=0.4,
                     inplace=True):
    """Convenience entry point for the fused GRPO loss.

    Checks that the three mandatory inputs are present, then forwards every
    argument unchanged, in order, to ``GrpoLoss.apply``.

    Returns:
        The ``(loss, kl, is_clipped)`` tuple produced by ``GrpoLoss``.
        ``kl`` is None when ``beta == 0``.

    Note:
        With the default ``inplace=True``, the backward pass reuses the
        storage of ``logits`` for its gradient, overwriting the logits.
    """
    mandatory = (logits, completion_ids, advantages)
    assert all(t is not None for t in mandatory), "must provide logits、completion_ids and advantages"

    forwarded = (logits, old_logp, ref_logp, completion_ids, advantages,
                 completion_mask, temperature, beta, eps_low, eps_high, inplace)
    return GrpoLoss.apply(*forwarded)

